# coding=utf-8
# Copyright 2019 The SEED Authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Implements generalized advantage estimation.

The implementation is based on the paper:
High-Dimensional Continuous Control Using Generalized Advantage Estimation
https://arxiv.org/abs/1506.02438
"""

import abc
import gin
import tensorflow as tf


def vtrace(values, rewards, done_terminated, done_abandoned, discount_factor,
           target_action_log_probs, behaviour_action_log_probs, lambda_=1.0,
           max_importance_weight=1., name='vtrace'):
  r"""Computes V-trace value targets and advantages.

  Combines generalized advantage estimation
  (https://arxiv.org/abs/1506.02438) with clipped importance weighting for
  off-policy correction.

  Args:
    values: A float32 tensor of shape [T+1, B] holding value estimates under
      the target policy for time steps i, i+1, ..., i+T.
    rewards: A float32 tensor of shape [T, B] with rewards obtained by the
      behaviour policy after steps i, i+1, ..., i+T-1.
    done_terminated: A bool tensor of shape [T, B]; True where the episode
      properly terminated after the corresponding action. Equivalent to
      entering a terminal state whose future rewards are all zero.
    done_abandoned: A bool tensor of shape [T, B]; True where the agent did
      not act further after the corresponding action without the episode
      properly terminating, e.g. because the maximum episode length was
      reached. For such steps the advantage is forced to zero and the target
      equals the input value (which generally yields a zero gradient).
    discount_factor: Float discount factor.
    target_action_log_probs: A float32 tensor of shape [T, B] with the
      log-probability of the taken action under the current (target) policy.
    behaviour_action_log_probs: A float32 tensor of shape [T, B] with the
      log-probability of the taken action under the behaviour policy.
    lambda_: Float mixing 1-step (lambda_=0) and n-step (lambda_=1)
      bootstrapping.
    max_importance_weight: Importance weights larger than this are clipped.
    name: String with the name scope that all operations are created in.

  Returns:
    A float32 tensor of shape [T, B] with value targets for training a
      baseline via (V(x_t) - vs_t)^2, and a float32 tensor of shape [T, B]
      of advantages. Gradients are stopped on both.
  """
  with tf.name_scope(name):
    # Clipped importance sampling ratios pi(a|x) / mu(a|x), computed in
    # log-space so the clipping happens before exponentiation.
    log_ratios = tf.minimum(
        target_action_log_probs - behaviour_action_log_probs,
        tf.math.log(max_importance_weight))
    importance_weights = tf.exp(log_ratios)

    # Temporal differences need special treatment where episodes ended:
    # - Termination: the agent's decision properly ended the episode, so the
    #   future value is enforced to be zero. We therefore bootstrap with zero
    #   rather than the next value estimate (which already belongs to the
    #   state after the reset).
    continue_mask = tf.cast(~done_terminated, tf.float32)
    bootstrap_values = continue_mask * values[1:]

    # - Abandonment: the episode was cut short, e.g. by a maximum episode
    #   length. Had the policy continued, it would have kept collecting
    #   reward, so we force the temporal difference -- and with it the
    #   advantage -- to zero instead.
    keep_mask = tf.cast(~done_abandoned, tf.float32)
    td_errors = rewards + discount_factor * bootstrap_values - values[:-1]
    td_errors = td_errors * keep_mask

    # In both cases, temporal differences from later time steps belong to a
    # different episode and must not leak backwards across the boundary.
    carry_mask = continue_mask * keep_mask

    # Accumulate temporal differences backwards through time, computing
    # advantages as we go via dynamic programming.
    carry = tf.zeros_like(values[0])
    targets = []
    advantages = []
    for t in reversed(range(int(rewards.shape[0]))):
      future = carry_mask[t] * discount_factor * lambda_ * carry
      advantage = td_errors[t] + future
      # No importance weight on the advantage: it is the advantage of exactly
      # the action that was taken.
      advantages.insert(0, advantage)
      # The accumulator, however, estimates the value of the current state,
      # so both terms are corrected by the importance weight.
      carry = importance_weights[t] * advantage
      targets.insert(0, values[t] + carry)

    # Stop gradients: the generalized advantage estimator itself must not be
    # differentiated through.
    targets = tf.convert_to_tensor(targets, dtype=tf.float32)
    advantages = tf.convert_to_tensor(advantages, dtype=tf.float32)
    return tf.stop_gradient(targets), tf.stop_gradient(advantages)




def gae(values, rewards, done_terminated, done_abandoned, discount_factor,
        target_action_log_probs=None, behaviour_action_log_probs=None,
        lambda_=1.0, name='gae'):
  """Generalized Advantage Estimator.

  This is V-trace without off-policy correction: all importance weights are
  fixed to one by passing identical (zero) log-probabilities for the target
  and behaviour policies.

  Args:
    See `vtrace` above. The `*_action_log_probs` arguments are accepted for
    interface compatibility but ignored.

  Returns:
    A float32 tensor of shape [T, B] with value targets that can be used to
      train a baseline (V(x_t) - vs_t)^2.
    A float32 tensor of shape [T, B] of advantages.
  """
  # With equal target and behaviour log-probs, every importance ratio is 1.
  zero_log_probs = tf.zeros_like(rewards)
  return vtrace(values, rewards, done_terminated, done_abandoned,
                discount_factor,
                target_action_log_probs=zero_log_probs,
                behaviour_action_log_probs=zero_log_probs,
                lambda_=lambda_, max_importance_weight=1., name=name)


class AdvantageEstimator(tf.Module, metaclass=abc.ABCMeta):
  """Interface for advantage estimators.

  Concrete subclasses (e.g. GAE, V-trace, n-step returns) implement
  `__call__` to turn a trajectory of values, rewards and episode-boundary
  flags into value targets and advantages.
  """

  @abc.abstractmethod
  def __call__(self, values, rewards, done_terminated, done_abandoned,
               discount_factor, target_action_log_probs,
               behaviour_action_log_probs):
    r"""Computes value function targets and advantages for a trajectory.

    Args:
      values: A float32 tensor of shape [T+1, B] holding value estimates
        under the target policy for time steps i, i+1, ..., i+T.
      rewards: A float32 tensor of shape [T, B] with rewards obtained by the
        behaviour policy after steps i, i+1, ..., i+T-1.
      done_terminated: A bool tensor of shape [T, B]; True where the episode
        properly terminated after the corresponding action. Equivalent to
        entering a terminal state whose future rewards are all zero.
      done_abandoned: A bool tensor of shape [T, B]; True where the agent did
        not act further after the corresponding action without the episode
        properly terminating, e.g. because the maximum episode length was
        reached. For such steps the advantage is forced to zero and the
        target equals the input value (which generally yields a zero
        gradient).
      discount_factor: Float discount factor.
      target_action_log_probs: A float32 tensor of shape [T, B] with the
        log-probability of the taken action under the current (target)
        policy.
      behaviour_action_log_probs: A float32 tensor of shape [T, B] with the
        log-probability of the taken action under the behaviour policy.

    Returns:
      A float32 tensor of shape [T, B] with value targets that can be used to
        train a baseline (V(x_t) - vs_t)^2.
      A float32 tensor of shape [T, B] of advantages.
    """
    raise NotImplementedError('`__call__()` is not implemented!')


@gin.configurable
class GAE(AdvantageEstimator):
  """Generalized Advantage Estimation (no importance weighting)."""

  def __init__(self, lambda_, name='gae'):
    """Creates a GAE advantage estimator.

    Args:
      lambda_: Float mixing 1-step (lambda_=0) and n-step (lambda_=1)
        bootstrapping.
      name: String with the module name; used as the op name scope in
        `__call__`.
    """
    # Bug fix: the `name` argument used to be accepted but silently dropped
    # (`super().__init__()` was called without it), while `__call__` reads
    # `self.name`. The old default 'gam' also looked like a typo; 'gae'
    # presumably matches the name tf.Module would have auto-derived from the
    # class name, keeping default behavior unchanged -- TODO confirm.
    super().__init__(name)
    self.lambda_ = lambda_

  def __call__(self, *args, **kwargs):
    """Computes targets and advantages; see `AdvantageEstimator.__call__`."""
    return gae(*args, **kwargs, lambda_=self.lambda_, name=self.name)


@gin.configurable
class VTrace(AdvantageEstimator):
  """V-trace advantage estimator with clipped importance weights."""

  def __init__(self, lambda_, max_importance_weight=1., name='vtrace'):
    """Creates a V-trace advantage estimator.

    Args:
      lambda_: Float mixing 1-step (lambda_=0) and n-step (lambda_=1)
        bootstrapping.
      max_importance_weight: Importance weights larger than this are clipped.
      name: String with the module name; used as the op name scope in
        `__call__`.
    """
    super().__init__(name)
    self.lambda_ = lambda_
    self.max_importance_weight = max_importance_weight

  def __call__(self, *args, **kwargs):
    """Computes targets and advantages; see `AdvantageEstimator.__call__`."""
    return vtrace(*args, **kwargs, lambda_=self.lambda_,
                  max_importance_weight=self.max_importance_weight,
                  name=self.name)


@gin.configurable
class NStep(AdvantageEstimator):
  """N-step returns."""

  def __init__(self, n, name='nstep2'):
    """Creates an n-step return estimator.

    Args:
      n: Integer number of steps to bootstrap over. Effectively capped at the
        unroll length in `__call__`.
      name: String with the module name; used as the op name scope in
        `__call__`.
    """
    super().__init__(name)
    self.n = n

  def __call__(self, values, rewards, done_terminated, done_abandoned,
               discount_factor, target_action_log_probs,
               behaviour_action_log_probs):
    """Computes n-step value targets and advantages.

    See `AdvantageEstimator.__call__` for argument documentation. The
    action log-probability arguments are accepted for interface
    compatibility but not used: no off-policy correction is applied.

    Returns:
      A float32 tensor of shape [T, B] of n-step value targets and a float32
      tensor of shape [T, B] of advantages (targets minus `values[:-1]`).
      Gradients are stopped on both.
    """
    with tf.name_scope(self.name):
      # We compute the n-step returns in min(n, unroll_length) steps.
      unroll_length = int(rewards.shape[0])
      eff_n = min(self.n, unroll_length)

      # We pad the time dimension with n-1 additional steps marked
      # abandoned=True so that the last n-1 steps of the unroll need no
      # special handling (an abandoned step bootstraps with its own value).
      values_pad = tf.zeros((eff_n - 1, values.shape[1]), dtype=tf.float32)
      done_terminated_pad = tf.zeros((eff_n - 1, values.shape[1]),
                                     dtype=tf.bool)
      done_abandoned_pad = tf.ones((eff_n - 1, values.shape[1]), dtype=tf.bool)
      rewards_pad = tf.zeros((eff_n - 1, values.shape[1]), dtype=tf.float32)

      nvalues = tf.concat([values, values_pad], axis=0)
      ndone_terminated = tf.concat([done_terminated, done_terminated_pad],
                                   axis=0)
      ndone_abandoned = tf.concat([done_abandoned, done_abandoned_pad], axis=0)
      nrewards = tf.concat([rewards, rewards_pad], axis=0)

      # Bootstrap values n steps ahead of each step in the unroll.
      future_value = nvalues[eff_n:]

      # Roll the bootstrap backwards one step at a time; after eff_n
      # iterations, future_value holds the eff_n-step return for each step.
      for i in range(eff_n):
        # Extract the length-T windows relevant to this backup step.
        start = eff_n - i - 1
        end = start + unroll_length
        rel_n_values = nvalues[start:end]
        rel_rewards = nrewards[start:end]
        rel_done_terminated = ndone_terminated[start:end]
        rel_done_abandoned = ndone_abandoned[start:end]

        # We compute the targets with special handling of episodes
        # which ended. We consider two cases:
        # - Termination: In this case, the agent took a decision that led to
        #     proper termination of the episode. In this case, the future value
        #     of the policy is enforced to be zero. This is done by setting the
        #     next step bootstrapping value to zero and not to the next value
        #     function (which is the value of the state after the reset).
        not_terminated_mask = tf.cast(~rel_done_terminated, tf.float32)
        next_step_bootstrap = not_terminated_mask * future_value

        # - Abandonment: The current episode was abandoned, e.g., due to a
        #     maximum episode length (or padding). If the policy would have
        #     continued, it would have continued to obtain rewards. We handle
        #     this by setting the value to the current value.
        not_abandoned_mask = tf.cast(~rel_done_abandoned, tf.float32)
        abandoned_mask = tf.cast(rel_done_abandoned, tf.float32)

        one_step_bootstrap = rel_rewards + discount_factor * next_step_bootstrap

        future_value = (not_abandoned_mask*one_step_bootstrap +
                        abandoned_mask*rel_n_values)

      # Advantages are the n-step targets relative to the current values.
      advantages = future_value - values[:-1]
      return tf.stop_gradient(future_value), tf.stop_gradient(advantages)
